import torch
import numpy as np

from dada.model.torch_model import TorchModel


class TorchNormModel(TorchModel):
    """Scaled power-of-norm objective ``f(x) = ||x||^p / p``.

    The same function is exposed twice: ``loss`` evaluates it on the model's
    torch state ``self.x``, and ``compute_value`` evaluates it on a NumPy array.

    Args:
        n_features: Forwarded unchanged to ``TorchModel.__init__``.
        p: Exponent applied to the norm. The result is also scaled by
            ``1 / p``, so ``p`` must be non-zero (``p == 0`` raises
            ``ZeroDivisionError`` when the objective is evaluated).
        init_point: Forwarded unchanged to ``TorchModel.__init__``.
    """

    def __init__(self, n_features: int, p: int, init_point: torch.Tensor = None):
        # NOTE(review): ``p`` is stored before the base constructor runs,
        # presumably so it is already available if ``TorchModel.__init__``
        # touches the objective -- keep this ordering.
        self.p = p
        super().__init__(n_features, init_point)

    def loss(self) -> torch.Tensor:
        """Return ``||self.x||^p / p`` as a torch tensor.

        ``torch.norm`` with default arguments is the Euclidean (Frobenius)
        norm taken over every entry of ``self.x``.
        """
        # assumes ``self.x`` is the tensor managed by ``TorchModel`` -- confirm
        magnitude = torch.norm(self.x)
        powered = magnitude.pow(self.p)
        return (1 / self.p) * powered

    def compute_value(self, point: np.ndarray):
        """NumPy counterpart of ``loss``: return ``||point||^p / p``.

        With default arguments ``np.linalg.norm`` returns the 2-norm of the
        flattened ``point``, matching the norm used by ``loss``.
        """
        scale = 1 / self.p
        return scale * np.linalg.norm(point) ** self.p
